import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
import numpy as np
import copy
import random

#bootstrap_bs = 5
#BatchNum = 100
#NumSample = 20
#B = 50
# Fixed seed values. None of them is read by live code in the visible part of
# this file; only a commented-out `torch.manual_seed(seed)` call mentions `seed`.
seed = 2
# presumably meant for NumPy's RNG -- TODO confirm against the caller
npseed = 6
# presumably meant for torch's RNG -- TODO confirm against the caller
torchseed = 5

def LossScaledTrace(test_model, train_data, d, train_size, B=3200):
    """Collect per-sample loss gradients of a model over a random data subset.

    A deep copy of ``test_model`` is moved to ``cuda:0`` when CUDA is
    available (CPU otherwise), so the caller's model is never modified.
    ``B`` samples are drawn from ``train_data`` without replacement and, for
    each one, the gradient of the cross-entropy loss with respect to every
    trainable parameter is flattened and written into one column of a
    ``(d, B)`` matrix ``Q``.

    NOTE(review): the visible code returns nothing. ``Q`` is filled and then
    discarded, and the ``S`` (``d x B``) and ``P`` (``B x B``) buffers are
    allocated but never written. Presumably the rest of the computation is
    still to be implemented -- confirm before relying on this function.

    Args:
        test_model: ``torch.nn.Module`` whose output for a batch of inputs is
            accepted by ``torch.nn.CrossEntropyLoss``.
        train_data: Dataset yielding ``(input, label)`` pairs.
        d: Number of rows of ``Q``. It must equal the total number of elements
            over all trainable parameters of the model, otherwise the column
            assignment fails with a shape mismatch.
        train_size: Number of samples in ``train_data``; only used to
            validate ``B``.
        B: Number of samples to draw (one column of ``Q`` per sample).

    Raises:
        ValueError: If ``B`` is larger than ``train_size``.
    """
    if B > train_size:
        # This case used to print the message and fall through, which then
        # crashed with an UnboundLocalError because no loader was created.
        raise ValueError(
            "Error: The number of samples cannot be larger than the train set size!"
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Work on a private copy so gradients never leak into the caller's model.
    model = copy.deepcopy(test_model).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    # The optimizer is only used to reset gradients between samples; no
    # parameter update (step) is ever taken, so the learning rate is irrelevant.
    optimizer = torch.optim.SGD(model.parameters(), lr=1)

    # Draw B distinct samples and feed them one at a time so every backward
    # pass yields the gradient of a single example.
    dataset = train_data
    sampler = RandomSampler(dataset, replacement=False, num_samples=B)
    single_loader = DataLoader(dataset=dataset, batch_size=1, sampler=sampler)

    # Allocate the three matrices directly on the target device.
    Q = torch.zeros(d, B, device=device)
    S = torch.zeros(d, B, device=device)
    P = torch.zeros(B, B, device=device)

    # Fill Q column by column with the flattened per-sample gradients.
    optimizer.zero_grad()
    for idx, (image, label) in enumerate(single_loader):
        image = image.to(device)
        label = label.to(device)
        output = model(image).to(torch.float32)
        loss = criterion(output, label)
        loss.backward()
        # Flatten each trainable parameter's gradient and concatenate once.
        # A parameter that did not take part in the forward pass has no
        # gradient; it contributes zeros so the column length stays d.
        grads = [
            (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
            for p in model.parameters()
            if p.requires_grad
        ]
        Q[:, idx] = torch.cat(grads)
        optimizer.zero_grad()

        #print(torch.matmul(Gradients[idx], torch.transpose(Gradients[idx])).size())
    #return torch.trace(torch.mm(Hessian, CovarianceMatrix)).cpu().numpy() / max([(FullLoss * 2), 10e-15]), \
    #       torch.norm(Hessian, p='fro').cpu().numpy() / 1, torch.trace(Hessian).cpu().numpy() / 1